{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Physionet 2017 | ECG Rhythm Classification\n",
    "## 4. Train Model (CPU Test)\n",
    "### Sebastian D. Goodfellow, Ph.D."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "# Setup Notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import 3rd party libraries\n",
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "# Deep learning libraries\n",
    "import tensorflow as tf\n",
    "\n",
    "# Import local Libraries\n",
    "sys.path.insert(0, r'C:\\Users\\sebig\\Documents\\code\\deep_ecg')\n",
    "from utils.plotting.time_series import plot_time_series_widget\n",
    "from utils.data.labels.one_hot_encoding import one_hot_encoding\n",
    "from utils.devices.device_check import print_device_counts, get_device_names\n",
    "from train.train import train\n",
    "from model.model import Model\n",
    "\n",
    "# Configure Notebook\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Load ECG Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample rate of the recordings\n",
    "fs = 300\n",
    "\n",
    "# Dataset folder: <parent of the current working directory>/data/training\n",
    "path = os.path.join(os.path.dirname(os.getcwd()), 'data', 'training')\n",
    "\n",
    "# Read the pickled dataset from disk\n",
    "# NOTE: pickle.load can execute arbitrary code, so only open files from a trusted source\n",
    "pickle_file = os.path.join(path, 'training_60s.pickle')\n",
    "with open(pickle_file, 'rb') as input_file:\n",
    "    data = pickle.load(input_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train dimensions: (5774, 18000, 1)\n",
      "y_train dimensions: (5774, 1)\n",
      "x_val dimensions: (2475, 18000, 1)\n",
      "y_val dimensions: (2475, 1)\n"
     ]
    }
   ],
   "source": [
    "# Training set: append a trailing axis of size 1 to the data and\n",
    "# turn the 'label_int' column into an integer column vector\n",
    "waveforms_train = data['data_train']\n",
    "x_train = waveforms_train.values.reshape(waveforms_train.shape + (1,))\n",
    "y_train = data['labels_train']['label_int'].values.reshape(-1, 1).astype(int)\n",
    "\n",
    "# Validation set: same treatment\n",
    "waveforms_val = data['data_val']\n",
    "x_val = waveforms_val.values.reshape(waveforms_val.shape + (1,))\n",
    "y_val = data['labels_val']['label_int'].values.reshape(-1, 1).astype(int)\n",
    "\n",
    "# Print dimensions\n",
    "for array_name, array in (('x_train', x_train), ('y_train', y_train), ('x_val', x_val), ('y_val', y_val)):\n",
    "    print('{} dimensions: {}'.format(array_name, array.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train dimensions: (5774, 18000, 1)\n",
      "y_train dimensions: (5774, 1)\n",
      "y_train_1hot dimensions: (5774, 3)\n",
      "x_val dimensions: (2475, 18000, 1)\n",
      "y_val dimensions: (2475, 1)\n",
      "y_val_1hot dimensions: (2475, 3)\n"
     ]
    }
   ],
   "source": [
    "# Derive the class count once, from the training labels, and use it for both splits.\n",
    "# Counting the unique labels of each split separately would request a different class\n",
    "# count for the validation set whenever that split happens to be missing a class.\n",
    "num_classes = len(np.unique(y_train.ravel()))\n",
    "\n",
    "# One hot encode the labels\n",
    "y_train_1hot = one_hot_encoding(labels=y_train.ravel(), classes=num_classes)\n",
    "y_val_1hot = one_hot_encoding(labels=y_val.ravel(), classes=num_classes)\n",
    "\n",
    "# Print dimensions\n",
    "print('x_train dimensions: ' + str(x_train.shape))\n",
    "print('y_train dimensions: ' + str(y_train.shape))\n",
    "print('y_train_1hot dimensions: ' + str(y_train_1hot.shape))\n",
    "print('x_val dimensions: ' + str(x_val.shape))\n",
    "print('y_val dimensions: ' + str(y_val.shape))\n",
    "print('y_val_1hot dimensions: ' + str(y_val_1hot.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train: Classes: [0 1 2]\n",
      "Train: Count: [3553  531 1690]\n",
      "Val: Classes: [0 1 2]\n",
      "Val: Count: [1523  227  725]\n"
     ]
    }
   ],
   "source": [
    "# Label lookup (label symbol -> integer class)\n",
    "label_lookup = {'N': 0, 'A': 1, 'O': 2, '~': 3}\n",
    "\n",
    "# Report the distinct classes and the per-class sample count of each split\n",
    "for split_name, split_labels in (('Train', y_train), ('Val', y_val)):\n",
    "    flat_labels = split_labels.ravel()\n",
    "    print(split_name + ': Classes: ' + str(np.unique(flat_labels)))\n",
    "    print(split_name + ': Count: ' + str(np.bincount(flat_labels)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "edbbe07d58be40b2b42df281cc985be7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(IntSlider(value=2886, description='index', max=5773), Output()), _dom_classes=('widget-i…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Rhythm names handed to the plotting widget\n",
    "label_list = ['Normal Sinus Rhythm', 'Atrial Fibrillation', 'Other Rhythm']\n",
    "\n",
    "# Plot time series\n",
    "plot_time_series_widget(\n",
    "    time_series=x_train,\n",
    "    labels=y_train,\n",
    "    fs=fs,\n",
    "    label_list=label_list\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Device Check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Workstation has 1 CPUs.\n",
      "Workstation has 0 GPUs.\n"
     ]
    }
   ],
   "source": [
    "# Print the CPU and GPU counts\n",
    "print_device_counts()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Initialize Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Save path for graphs, summaries, and checkpoints\n",
    "save_path = r'C:\\Users\\sebig\\Desktop\\tensorboard\\deep_ecg\\test'\n",
    "\n",
    "# Model name\n",
    "model_name = 'test_10'\n",
    "\n",
    "# Maximum number of checkpoints to keep\n",
    "max_to_keep = 20\n",
    "\n",
    "# Set random state\n",
    "seed = 0\n",
    "tf.set_random_seed(seed)\n",
    "\n",
    "# Training dataset dimensions (only length and channels are used further down in this cell)\n",
    "m, length, channels = x_train.shape\n",
    "\n",
    "# Number of label classes = width of the one-hot label array\n",
    "classes = y_train_1hot.shape[1]\n",
    "\n",
    "# Choose network\n",
    "network_name = 'DeepECG'\n",
    "\n",
    "# Network inputs\n",
    "network_parameters = {\n",
    "    'length': length,\n",
    "    'channels': channels,\n",
    "    'classes': classes,\n",
    "    'seed': seed,\n",
    "}\n",
    "\n",
    "# Create model\n",
    "model = Model(model_name=model_name, network_name=network_name, network_parameters=network_parameters,\n",
    "              save_path=save_path, max_to_keep=max_to_keep)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# 4. Train Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Hyper-parameters\n",
    "learning_rate = 0.001\n",
    "epochs = 10\n",
    "minibatch_size = 1\n",
    "\n",
    "# Train on the first training example only and validate on the first validation example only\n",
    "train(\n",
    "    model=model,\n",
    "    x_train=x_train[:1],\n",
    "    y_train=y_train_1hot[:1],\n",
    "    x_val=x_val[:1],\n",
    "    y_val=y_val_1hot[:1],\n",
    "    learning_rate=learning_rate,\n",
    "    epochs=epochs,\n",
    "    mini_batch_size=minibatch_size\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
